# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the 'License');
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an 'AS IS' BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
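"""
 This transformation rule tries to identify the PRelu structure generated by
 Keras and convert it to a single Prelu op. It can also fuse a Prelu op into a
 preceding _FusedConv2D, _FusedMatMul or FusedDepthwiseConv2dNative op.
"""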

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.core.framework import attr_value_pb2
from tensorflow.python.framework import tensor_util

from tensorflowjs.converters import graph_rewrite_util

def fuse_ops_for_prelu(input_graph_def):
  """Modifies the provided graph by fusing a set of ops into a single Prelu op.

  The formula of PReLU is:
  f(x) = alpha * x for x < 0, f(x) = x for x >= 0.

  `x` is the input, and `alpha` is a trainable tensor which can be broadcasted
  to the shape of `x`.

  There's no native PRelu op in TensorFlow, so Keras generates the following
  structure which does the equivalent calculation:
  f(x) = Relu(x) + (-alpha * Relu(-x))

  Practically, alpha is always a constant in the inference graph. Therefore,
  we're looking for the structure:

  f(x) = Add(Relu(x), Mul(negative_alpha, Relu(Neg(x))))

  where `negative_alpha` is a Const node and the operands of the Add (and of
  the Mul) may appear in either order. When the pattern matches:
    * the Const node is negated in place so that it holds `alpha`,
    * the Relu(x) node becomes Prelu(x, alpha),
    * the Add node becomes an Identity of the Prelu, and it is marked for
      removal via `cleanup_graph_def` together with the Mul, Relu(Neg(x)) and
      Neg nodes.
  The rewrite is skipped when any node it rewrites or removes is referenced
  from outside the pattern, since that reader would otherwise see a different
  value or a dangling input.

  Args:
    input_graph_def: A GraphDef containing a model.

  Returns:
    Modified graph with Prelu ops generated, and modified weights.

  Raises:
    ValueError: If the graph is badly formed with duplicate node names.
  """
  input_node_map = {}
  for node in input_graph_def.node:
    if node.name in input_node_map:
      raise ValueError('Duplicate node names detected for %s' % node.name)
    input_node_map[node.name] = node

  # How many inputs (data or control) across the graph reference each node.
  # Input strings look like 'name', 'name:port' or '^name'.
  reference_count = {}
  for node in input_graph_def.node:
    for input_name in node.input:
      source_name = input_name.lstrip('^').split(':')[0]
      reference_count[source_name] = reference_count.get(source_name, 0) + 1

  nodes_to_skip = {}
  inputs_to_remove = []
  updated_alpha = []
  for node in input_graph_def.node:
    if node.op not in ('Add', 'AddV2') or len(node.input) != 2:
      continue

    lhs_op = graph_rewrite_util.node_from_map(input_node_map, node.input[0])
    rhs_op = graph_rewrite_util.node_from_map(input_node_map, node.input[1])
    # Add is commutative: accept both Relu(x) + Mul(...) and Mul(...) + Relu(x).
    if lhs_op.op == 'Relu' and rhs_op.op == 'Mul':
      relu_input_op, mul_op = lhs_op, rhs_op
    elif lhs_op.op == 'Mul' and rhs_op.op == 'Relu':
      relu_input_op, mul_op = rhs_op, lhs_op
    else:
      continue

    # The Mul must combine a Const (-alpha) with Relu(Neg(x)), in any order.
    neg_alpha_op = None
    relu_neg_input_op = None
    for name in mul_op.input:
      op = graph_rewrite_util.node_from_map(input_node_map, name)
      if op.op == 'Const' and neg_alpha_op is None:
        neg_alpha_op = op
      elif op.op == 'Relu' and relu_neg_input_op is None:
        relu_neg_input_op = op
    if (neg_alpha_op is None or relu_neg_input_op is None or
        len(relu_neg_input_op.input) != 1):
      continue

    # This detects a Neg op followed by a separated Relu op.
    neg_input_op = graph_rewrite_util.node_from_map(
        input_node_map, relu_neg_input_op.input[0])
    if neg_input_op.op != 'Neg' or len(neg_input_op.input) != 1:
      continue

    # Both branches must read the same tensor x.
    if (not relu_input_op.input or
        relu_input_op.input[0] != neg_input_op.input[0]):
      continue

    # Every node rewritten or removed below must be referenced exactly once,
    # i.e. only from within this pattern.
    if any(reference_count.get(op.name, 0) != 1
           for op in (relu_input_op, mul_op, relu_neg_input_op, neg_input_op)):
      continue

    # The pattern is fully matched; only now is it safe to flip the sign of
    # the Const so it stores alpha. updated_alpha ensures a Const shared by
    # several PReLU patterns is negated only once.
    # NOTE(review): a Const that also feeds ops outside PReLU patterns would
    # still be negated for those readers; confirm this cannot happen upstream.
    _create_alpha_node(neg_alpha_op, updated_alpha)

    relu_input_op.op = 'Prelu'
    relu_input_op.input.extend([neg_alpha_op.name])
    # Remove the T attr that is defined in Relu op, since our custom Prelu op
    # definition does not have that.
    del relu_input_op.attr['T']

    # Turn the Add into an Identity of the new Prelu so cleanup can splice it
    # out of the graph.
    node.op = 'Identity'
    del node.input[:]
    node.input.append(relu_input_op.name)

    for op in (mul_op, relu_neg_input_op, neg_input_op, node):
      nodes_to_skip[op.name] = True
    inputs_to_remove.append(node)

  return graph_rewrite_util.cleanup_graph_def(
      input_graph_def, nodes_to_skip, inputs_to_remove)

def _create_alpha_node(neg_alpha_op, updated_alpha):
  """Negates the value of a Const node in place, at most once per node name.

  Args:
    neg_alpha_op: A Const NodeDef whose `value` attr is replaced by the
      negation of the value it held before the call.
    updated_alpha: List of names of Const nodes already negated. The name of
      `neg_alpha_op` is appended after negation, so repeated calls for the
      same node leave it untouched.
  """
  if neg_alpha_op.name in updated_alpha:
    return
  negated = -graph_rewrite_util.values_from_const(neg_alpha_op)
  negated_proto = tensor_util.make_tensor_proto(
      negated, negated.dtype.type, negated.shape)
  neg_alpha_op.attr['value'].CopyFrom(
      attr_value_pb2.AttrValue(tensor=negated_proto))
  updated_alpha.append(neg_alpha_op.name)

def fuse_prelu_with_fused_conv2d_or_matmul(input_graph_def):
  """Fuses Prelu ops into the fused conv/matmul op that produces their input.

  Tensorflow does not support Prelu op, and the grappler remap optimizer
  will not fuse the prelu op with _FusedConv2D op. This method searches for
  the pattern (_FusedConv2D || _FusedMatMul || FusedDepthwiseConv2dNative)
  followed by Prelu and fuses them into a single op: 'Prelu' is appended to
  the op's `fused_ops` attr, the alpha tensor is appended to its inputs and
  its `num_args` attr is incremented. The Prelu node becomes an Identity of
  the fused op and is marked for removal via `cleanup_graph_def`.

  The fusion is skipped when the fused op already fuses more than one op, or
  when its output is read by anything other than the Prelu, since fusing
  would change the value seen by those other readers.

  Args:
    input_graph_def: A GraphDef containing a model.

  Returns:
    Modified graph with Prelu ops fused with _FusedConv2D, _FusedMatMul or
    FusedDepthwiseConv2dNative as activation function.

  Raises:
    ValueError: If the graph is badly formed with duplicate node names.
  """
  input_node_map = {}
  for node in input_graph_def.node:
    if node.name in input_node_map:
      raise ValueError('Duplicate node names detected for %s' % node.name)
    input_node_map[node.name] = node

  # Number of data inputs across the graph that read each node. Input strings
  # look like 'name' or 'name:port'; control inputs ('^name') do not read the
  # value, so fusing does not affect them and they are not counted.
  reader_count = {}
  for node in input_graph_def.node:
    for input_name in node.input:
      if input_name.startswith('^'):
        continue
      source_name = input_name.split(':')[0]
      reader_count[source_name] = reader_count.get(source_name, 0) + 1

  nodes_to_skip = {}
  inputs_to_remove = []
  for node in input_graph_def.node:
    # A Prelu needs both its input tensor and its alpha tensor.
    if node.op != 'Prelu' or len(node.input) < 2:
      continue

    fused_op = graph_rewrite_util.node_from_map(input_node_map, node.input[0])
    if fused_op.op not in ('_FusedConv2D', '_FusedMatMul',
                           'FusedDepthwiseConv2dNative'):
      continue
    # Skip ops that already fuse more than one op.
    if len(fused_op.attr['fused_ops'].list.s) > 1:
      continue
    # The Prelu must be the only reader of the fused op's output.
    if reader_count.get(fused_op.name, 0) != 1:
      continue

    alpha_tensor_name = node.input[1]
    fused_op.input.extend([alpha_tensor_name])
    fused_op.attr['fused_ops'].list.s.extend([b'Prelu'])
    fused_op.attr['num_args'].i += 1

    # Turn the Prelu into an Identity of the fused op so cleanup can splice
    # it out of the graph.
    node.op = 'Identity'
    node.input[:] = [node.input[0]]
    nodes_to_skip[node.name] = True
    inputs_to_remove.append(node)

  return graph_rewrite_util.cleanup_graph_def(
      input_graph_def, nodes_to_skip, inputs_to_remove)
      